{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "94f8a023",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/retrievers/recursive_retriever_nodes.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "025f3e20-aec9-491c-8c90-234aed406a25",
   "metadata": {},
   "source": [
    "# Recursive Retriever + Node References\n",
    "\n",
    "This guide shows how you can use recursive retrieval to traverse node relationships and fetch nodes based on \"references\".\n",
    "\n",
    "A node reference is a node that stands in for another node. At retrieval time, the reference (for example, a small chunk or a summary) is matched against the query, and the retriever then follows it to the node it points to instead of returning the reference's own text. Multiple references can point to the same node.\n",
    "\n",
    "In this guide we explore two uses of node references:\n",
    "- **Chunk references**: Different chunk sizes referring to a bigger chunk\n",
    "- **Metadata references**: Summaries + Generated Questions referring to a bigger chunk"
   ]
  },
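  {
   "cell_type": "markdown",
   "id": "sketch-reference-resolution-md",
   "metadata": {},
   "source": [
    "To make the idea concrete, the next cell is a minimal, library-free sketch of how references are resolved. It is not LlamaIndex's implementation. The `Node` dataclass, the `resolve` helper, and the toy ids are all hypothetical, and a hard-coded list of hits replaces the retrieval step. A node whose `ref_id` is `None`, or equal to its own id, is treated as a final text node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-reference-resolution-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Dict, Optional\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Node:\n",
    "    node_id: str\n",
    "    text: str\n",
    "    # id of the node this one refers to (hypothetical field name)\n",
    "    ref_id: Optional[str] = None\n",
    "\n",
    "\n",
    "def resolve(node: Node, node_dict: Dict[str, Node]) -> Node:\n",
    "    \"\"\"Follow references until reaching a node that does not point elsewhere.\"\"\"\n",
    "    seen = set()\n",
    "    while node.ref_id is not None and node.ref_id != node.node_id:\n",
    "        if node.node_id in seen:\n",
    "            raise ValueError(f\"reference cycle at {node.node_id}\")\n",
    "        seen.add(node.node_id)\n",
    "        node = node_dict[node.ref_id]\n",
    "    return node\n",
    "\n",
    "\n",
    "# two references (a summary and a question) that point to the same chunk\n",
    "node_dict = {\n",
    "    \"node-0\": Node(\"node-0\", \"full text of a 1024-token chunk\"),\n",
    "    \"sum-0\": Node(\"sum-0\", \"a short summary of node-0\", ref_id=\"node-0\"),\n",
    "    \"q-0\": Node(\"q-0\", \"a question node-0 can answer\", ref_id=\"node-0\"),\n",
    "}\n",
    "hits = [node_dict[\"q-0\"], node_dict[\"sum-0\"]]  # pretend these matched the query\n",
    "print([resolve(hit, node_dict).node_id for hit in hits])  # ['node-0', 'node-0']"
   ]
  },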
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87ca1171",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-llms-openai\n",
    "%pip install llama-index-readers-file\n",
    "%pip install llama-index-embeddings-huggingface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee583e89-a508-493e-b232-42e520ce19de",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env OPENAI_API_KEY=YOUR_OPENAI_KEY"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691e6b21",
   "metadata": {},
   "source": [
    "If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42164863",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index pypdf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "273f38de-e79a-4ce2-ad4e-2c70afc33f34",
   "metadata": {},
   "source": [
    "## Load Data + Setup\n",
    "\n",
    "In this section we download the Llama 2 paper and create an initial set of nodes (chunk size 1024)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eb829ef-b54b-4095-a832-6d1d115aa645",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-01-01 11:13:01--  https://arxiv.org/pdf/2307.09288.pdf\n",
      "Resolving arxiv.org (arxiv.org)... 151.101.3.42, 151.101.131.42, 151.101.67.42, ...\n",
      "Connecting to arxiv.org (arxiv.org)|151.101.3.42|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 13661300 (13M) [application/pdf]\n",
      "Saving to: ‘data/llama2.pdf’\n",
      "\n",
      "data/llama2.pdf     100%[===================>]  13.03M  27.3MB/s    in 0.5s    \n",
      "\n",
      "2024-01-01 11:13:02 (27.3 MB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p 'data/'\n",
    "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd97455-5ff3-43ee-8222-f496ec234dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_index.readers.file import PDFReader\n",
    "from llama_index.core.response.notebook_utils import display_source_node\n",
    "from llama_index.core.retrievers import RecursiveRetriever\n",
    "from llama_index.core.query_engine import RetrieverQueryEngine\n",
    "from llama_index.core import VectorStoreIndex\n",
    "from llama_index.llms.openai import OpenAI\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a07c0e42-1ae8-4267-9355-6bb75323f82a",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PDFReader()\n",
    "docs0 = loader.load_data(file=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "493e5492-a6ae-4e3e-aa23-274c0605b165",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Document\n",
    "\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "docs = [Document(text=doc_text)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2abcd3-6cae-49dd-8719-9b738d000652",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.schema import IndexNode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b997ae-9260-4ae7-af2f-0f8d38625d32",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_parser = SentenceSplitter(chunk_size=1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cda44b0-fd27-4255-9aa7-08d358635772",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_nodes = node_parser.get_nodes_from_documents(docs)\n",
    "# give each node a stable, deterministic id (node-0, node-1, ...) instead of a random UUID\n",
    "for idx, node in enumerate(base_nodes):\n",
    "    node.id_ = f\"node-{idx}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e47623-b67d-45d6-9b24-33ba84719f1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "\n",
    "embed_model = resolve_embed_model(\"local:BAAI/bge-small-en\")\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f43ebab2-fc46-41ea-8a92-9148994d793f",
   "metadata": {},
   "source": [
    "## Baseline Retriever\n",
    "\n",
    "Define a baseline retriever that simply fetches the top-k raw text nodes by embedding similarity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "704fb3da-710e-4ad9-b630-565911917f0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_index = VectorStoreIndex(base_nodes, embed_model=embed_model)\n",
    "base_retriever = base_index.as_retriever(similarity_top_k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160c339b-601a-486b-9e17-dd6cc9f133ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "retrievals = base_retriever.retrieve(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "632610f3-c8f2-440a-ab27-5ca7d65f882a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-26<br>**Similarity:** 0.8581930837671874<br>**Text:** AsLLMsareintegratedanddeployed,welookforwardto\n",
       "continuing research that will amplify their potential for positive impact on these important social issues.\n",
       "4.2 Safety Fine-Tuning\n",
       "In this section, we describe our approach to safety fine-tuning, including safety categories, annotation\n",
       "guidelines,andthetechniquesweusetomitigatesafetyrisks. Weemployaprocesssimilartothegeneral\n",
       "fine-tuning methods as described in Section 3, with some notable differences related to safety concerns.\n",
       "Specifically, we use the following techniques in safety fine-tuning:\n",
       "1.Supervised Safety Fine-Tuning : We initialize by gathering adversarial prompts and safe demonstra-\n",
       "tions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches\n",
       "themodeltoalignwithoursafetyguidelinesevenbeforeRLHF,andthuslaysthefoundationfor\n",
       "high-quality human preference data annotation.\n",
       "2.Safety RLHF : Subsequently, we integrate safety in the general RLHF pipeline described in Sec-\n",
       "tion 3.2.2. This includes training a safety-specific reward model and gathering more challenging\n",
       "adversarial prompts for rejection sampling style fine-tuning and PPO optimization.\n",
       "3.SafetyContextDistillation : Finally,werefineourRLHFpipelinewithcontextdistillation(Askell\n",
       "etal.,2021b). Thisinvolvesgeneratingsafermodelresponsesbyprefixingapromptwithasafety\n",
       "preprompt, e.g., “You are a safe and responsible assistant,” and then fine-tuning the model on the safer\n",
       "responses without the preprompt, which essentially distill...<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-33<br>**Similarity:** 0.8442806881920294<br>**Text:** Examples of evaluation prompts and responses can be found in Appendix A.4.6.\n",
       "We then asked raters to judge models for safety violations on a five-point Likert scale with the following\n",
       "definitions:\n",
       "•5- No safety violations and very helpful\n",
       "•4- No safety violations and only minor non-safety issues\n",
       "•3- No safety violations but not helpful or other major non-safety issues\n",
       "•2- Mild or moderate safety violations\n",
       "29\n",
       "\n",
       "(a)Overall violation percentage.\n",
       " (b)Overall safety and helpfulness mean rating.\n",
       "Figure17: Overallsafetymeasures. Left: Llama 2-Chat haslowviolationpercentageoverallacrossmodel\n",
       "sizes.Right: Llama 2-Chat has high safety and helpfulness mean rating overall across model sizes. It is\n",
       "important to note that these results are subject to limitations of the prompt set, subjectivity of the review\n",
       "guidelines, and subjectivity of individual raters.\n",
       "•1- Severe safety violations\n",
       "We consider a rating of 1 or 2 as violation and use violation percentage as our main evaluation metric, with\n",
       "themeanratingasasupplement. Eachexampleisannotatedbythreeannotatorsandwetakethemajority\n",
       "votetodetermineiftheresponseisviolatingornot. WeusedGwet’sAC1/2statistictomeasureinter-rater\n",
       "reliability(IRR)asinthehelpfulnesshumanevaluation. TheIRRscoresrangefrom 0.70to0.95depending\n",
       "on the annotation batch, indicating a high degree of agreement among annotators on safety assessments.\n",
       "OnLlama 2-Chat annotations, the average IRR is 0.92according to Gwet’s AC2 measure. We see lower IRR\n",
       "scoresonbatcheswherethemo...<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for n in retrievals:\n",
    "    display_source_node(n, source_length=1500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96dd8a01-1cae-4614-beab-5b5e0434fefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine_base = RetrieverQueryEngine.from_args(base_retriever, llm=llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82ae66ff-7d12-45c8-9b1a-adb20bd3c7ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The key concepts for safety fine-tuning include supervised safety fine-tuning, safety RLHF (Reinforcement Learning from Human Feedback), and safety context distillation. In supervised safety fine-tuning, adversarial prompts and safe demonstrations are gathered and included in the general supervised fine-tuning process. This helps the model align with safety guidelines and lays the foundation for high-quality human preference data annotation. Safety RLHF involves integrating safety in the general RLHF pipeline, which includes training a safety-specific reward model and gathering more challenging adversarial prompts for rejection sampling style fine-tuning and PPO (Proximal Policy Optimization) optimization. Safety context distillation is the final step, where the RLHF pipeline is refined with context distillation. This involves generating safer model responses by prefixing a prompt with a safety preprompt and then fine-tuning the model on the safer responses without the preprompt.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine_base.query(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5431df3-d255-4492-bce4-bbebde6f2306",
   "metadata": {},
   "source": [
    "## Chunk References: Smaller Child Chunks Referring to Bigger Parent Chunk\n",
    "\n",
    "In this usage example, we show how to build a graph of smaller chunks pointing to bigger parent chunks.\n",
    "\n",
    "During query-time, we retrieve smaller chunks, but we follow references to bigger chunks. This allows us to have more context for synthesis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c784d8-71e6-42bc-84d9-a2aea4217b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "sub_chunk_sizes = [128, 256, 512]\n",
    "sub_node_parsers = [\n",
    "    SentenceSplitter(chunk_size=c, chunk_overlap=20) for c in sub_chunk_sizes\n",
    "]\n",
    "\n",
    "all_nodes = []\n",
    "for base_node in base_nodes:\n",
    "    for n in sub_node_parsers:\n",
    "        sub_nodes = n.get_nodes_from_documents([base_node])\n",
    "        # each child chunk becomes an IndexNode that refers to its parent's id\n",
    "        sub_inodes = [\n",
    "            IndexNode.from_text_node(sn, base_node.node_id) for sn in sub_nodes\n",
    "        ]\n",
    "        all_nodes.extend(sub_inodes)\n",
    "\n",
    "    # also index the parent node itself; it refers to its own id\n",
    "    original_node = IndexNode.from_text_node(base_node, base_node.node_id)\n",
    "    all_nodes.append(original_node)"
   ]
  },
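  {
   "cell_type": "markdown",
   "id": "sketch-child-refs-md",
   "metadata": {},
   "source": [
    "For every parent, the loop above produces child chunks at three sizes plus the parent itself, and all of them refer to the parent's id. The next cell mirrors that shape in plain Python so it can be inspected without any API calls. It is only a sketch. `split_words` is a hypothetical word-window splitter standing in for `SentenceSplitter`: it ignores sentence boundaries and counts words, not tokens. `build_child_refs` is likewise hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-child-refs-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Tuple\n",
    "\n",
    "\n",
    "def split_words(text: str, size: int, overlap: int) -> List[str]:\n",
    "    \"\"\"Split text into windows of `size` words; neighbours share `overlap` words.\"\"\"\n",
    "    if not 0 <= overlap < size:\n",
    "        raise ValueError(\"need 0 <= overlap < size\")\n",
    "    words = text.split()\n",
    "    if not words:\n",
    "        return []\n",
    "    step = size - overlap\n",
    "    # stop once a window would add no words beyond the previous window\n",
    "    return [\n",
    "        \" \".join(words[start : start + size])\n",
    "        for start in range(0, max(len(words) - overlap, 1), step)\n",
    "    ]\n",
    "\n",
    "\n",
    "def build_child_refs(\n",
    "    parents: Dict[str, str], sizes: List[int], overlap: int\n",
    ") -> List[Tuple[str, str]]:\n",
    "    \"\"\"Return (chunk_text, parent_id) pairs for every child chunk and each parent.\"\"\"\n",
    "    refs: List[Tuple[str, str]] = []\n",
    "    for parent_id, text in parents.items():\n",
    "        for size in sizes:\n",
    "            refs.extend((chunk, parent_id) for chunk in split_words(text, size, overlap))\n",
    "        # like `original_node` above, the parent also refers to its own id\n",
    "        refs.append((text, parent_id))\n",
    "    return refs\n",
    "\n",
    "\n",
    "parents = {\"node-0\": \"one two three four five six seven eight nine ten\"}\n",
    "refs = build_child_refs(parents, sizes=[4, 8], overlap=1)\n",
    "for chunk, parent_id in refs:\n",
    "    print(parent_id, \"<-\", chunk)"
   ]
  },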
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d614088-b122-40ad-811a-29cc0c2a295e",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes_dict = {n.node_id: n for n in all_nodes}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a44ef2d5-0342-4073-831f-f35dd6f04dc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_index_chunk = VectorStoreIndex(all_nodes, embed_model=embed_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c06af99f-02be-4055-a6ea-3071ffe8fc8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_retriever_chunk = vector_index_chunk.as_retriever(similarity_top_k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7c5e43-45b5-42d6-afc5-cb81ed3cb211",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever_chunk = RecursiveRetriever(\n",
    "    \"vector\",\n",
    "    retriever_dict={\"vector\": vector_retriever_chunk},\n",
    "    node_dict=all_nodes_dict,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e9f7bcb-5442-4d2d-a7eb-814b68ebb45c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;3;34mRetrieving with query id None: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0m\u001b[1;3;38;5;200mRetrieved node with id, entering: node-26\n",
      "\u001b[0m\u001b[1;3;34mRetrieving with query id node-26: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0m\u001b[1;3;38;5;200mRetrieved node with id, entering: node-1\n",
      "\u001b[0m\u001b[1;3;34mRetrieving with query id node-1: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0m"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-26<br>**Similarity:** 0.8809071991986446<br>**Text:** AsLLMsareintegratedanddeployed,welookforwardto\n",
       "continuing research that will amplify their potential for positive impact on these important social issues.\n",
       "4.2 Safety Fine-Tuning\n",
       "In this section, we describe our approach to safety fine-tuning, including safety categories, annotation\n",
       "guidelines,andthetechniquesweusetomitigatesafetyrisks. Weemployaprocesssimilartothegeneral\n",
       "fine-tuning methods as described in Section 3, with some notable differences related to safety concerns.\n",
       "Specifically, we use the following techniques in safety fine-tuning:\n",
       "1.Supervised Safety Fine-Tuning : We initialize by gathering adversarial prompts and safe demonstra-\n",
       "tions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches\n",
       "themodeltoalignwithoursafetyguidelinesevenbeforeRLHF,andthuslaysthefoundationfor\n",
       "high-quality human preference data annotation.\n",
       "2.Safety RLHF : Subsequently, we integrate safety in the general RLHF pipeline described in Sec-\n",
       "tion 3.2.2. This includes training a safety-specific reward model and gathering more challenging\n",
       "adversarial prompts for rejection sampling style fine-tuning and PPO optimization.\n",
       "3.SafetyContextDistillation : Finally,werefineourRLHFpipelinewithcontextdistillation(Askell\n",
       "etal.,2021b). Thisinvolvesgeneratingsafermodelresponsesbyprefixingapromptwithasafety\n",
       "preprompt, e.g., “You are a safe and responsible assistant,” and then fine-tuning the model on the safer\n",
       "responses without the preprompt, which essentially distillsthe safety preprompt (context) into the\n",
       "model. Weuseatargetedapproachthatallowsoursafetyrewardmodeltochoosewhethertouse\n",
       "context distillation for each sample.\n",
       "4.2.1 Safety Categories and Annotation Guidelines\n",
       "Based on limitations of LLMs known from prior work, we design instructions for our annotation team to\n",
       "createadversarialpromptsalongtwodimensions: a riskcategory ,orpotentialtopicaboutwhichtheLLM\n",
       "couldproduceunsafecontent;andan attackvector ,orquestionstyletocoverdifferentvarietiesofprompts\n",
       "...<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-1<br>**Similarity:** 0.8744334039911964<br>**Text:** . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 9\n",
       "3.2 Reinforcement Learning with Human Feedback (RLHF) . . . . . . . . . . . . . . . . . . . . . 9\n",
       "3.3 System Message for Multi-Turn Consistency . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 16\n",
       "3.4 RLHF Results . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 17\n",
       "4 Safety 20\n",
       "4.1 Safety in Pretraining . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 20\n",
       "4.2 Safety Fine-Tuning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 23\n",
       "4.3 Red Teaming . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 28\n",
       "4.4 Safety Evaluation of Llama 2-Chat . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 29\n",
       "5 Discussion 32\n",
       "5.1 Learnings and Observations . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 32\n",
       "5.2 Limitations and Ethical Considerations . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 34\n",
       "5.3 Responsible Release Strategy . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 35\n",
       "6 Related Work 35\n",
       "7 Conclusion 36\n",
       "A Appendix 46\n",
       "A.1 Contributions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "nodes = retriever_chunk.retrieve(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")\n",
    "for node in nodes:\n",
    "    display_source_node(node, source_length=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411f26ad-d13b-4858-938e-efcfa899e8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine_chunk = RetrieverQueryEngine.from_args(retriever_chunk, llm=llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd98366-0d5f-4d04-87cd-b811990b7485",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1;3;34mRetrieving with query id None: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0m\u001b[1;3;38;5;200mRetrieved node with id, entering: node-26\n",
      "\u001b[0m\u001b[1;3;34mRetrieving with query id node-26: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0m\u001b[1;3;38;5;200mRetrieved node with id, entering: node-1\n",
      "\u001b[0m\u001b[1;3;34mRetrieving with query id node-1: Can you tell me about the key concepts for safety finetuning\n",
      "\u001b[0mThe key concepts for safety fine-tuning include supervised safety fine-tuning, safety RLHF (Reinforcement Learning with Human Feedback), and safety context distillation. Supervised safety fine-tuning involves gathering adversarial prompts and safe demonstrations to teach the model to align with safety guidelines. Safety RLHF integrates safety into the general RLHF pipeline by training a safety-specific reward model and gathering challenging adversarial prompts for rejection sampling style fine-tuning and PPO optimization. Safety context distillation involves generating safer model responses by prefixing a prompt with a safety preprompt and fine-tuning the model on the safer responses without the preprompt. These techniques aim to mitigate safety risks and improve the model's ability to provide safe and responsible responses.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine_chunk.query(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bcc7379-c077-40b7-ba4e-f47f80def0c7",
   "metadata": {},
   "source": [
    "## Metadata References: Summaries + Generated Questions referring to a bigger chunk\n",
    "\n",
    "In this usage example, we define additional context that references the source node.\n",
    "\n",
    "This additional context consists of a summary and five generated questions for each source chunk.\n",
    "\n",
    "At query time, we retrieve these summaries and questions, and then follow their references to the source chunks. This gives the LLM more context for synthesis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e3c4f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e40c5e-4868-487f-aaf4-f333aa4bda66",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.schema import IndexNode\n",
    "from llama_index.core.extractors import (\n",
    "    SummaryExtractor,\n",
    "    QuestionsAnsweredExtractor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c5d6f87-790e-4b82-abb2-cc6944678b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "extractors = [\n",
    "    SummaryExtractor(summaries=[\"self\"], show_progress=True),\n",
    "    QuestionsAnsweredExtractor(questions=5, show_progress=True),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e47c706c-940e-499d-b742-eaf09a230b0d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 93/93 [01:13<00:00,  1.27it/s]\n",
      "100%|██████████| 93/93 [00:49<00:00,  1.88it/s]\n"
     ]
    }
   ],
   "source": [
    "# run each extractor over the base nodes; each returns one metadata dict per node\n",
    "node_to_metadata = {}\n",
    "for extractor in extractors:\n",
    "    metadata_dicts = extractor.extract(base_nodes)\n",
    "    for node, metadata in zip(base_nodes, metadata_dicts):\n",
    "        if node.node_id not in node_to_metadata:\n",
    "            node_to_metadata[node.node_id] = metadata\n",
    "        else:\n",
    "            node_to_metadata[node.node_id].update(metadata)"
   ]
  },
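  {
   "cell_type": "markdown",
   "id": "sketch-merge-metadata-md",
   "metadata": {},
   "source": [
    "The merge step above can be checked in isolation. In the sketch below, made-up strings stand in for extractor output, and `merge_metadata` is a hypothetical helper. To our knowledge, `SummaryExtractor` and `QuestionsAnsweredExtractor` write the keys `section_summary` and `questions_this_excerpt_can_answer`, but treat those names as an assumption. Unlike the loop above, `setdefault` starts each node with a fresh dict, so the extractors' own dicts are not mutated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-merge-metadata-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def merge_metadata(\n",
    "    node_ids: List[str], per_extractor: List[List[Dict[str, str]]]\n",
    ") -> Dict[str, Dict[str, str]]:\n",
    "    \"\"\"Combine each extractor's list of dicts into one metadata dict per node id.\"\"\"\n",
    "    merged: Dict[str, Dict[str, str]] = {}\n",
    "    for metadata_dicts in per_extractor:\n",
    "        # assumes one dict per input node, in input order (the loop above relies on this too)\n",
    "        for node_id, metadata in zip(node_ids, metadata_dicts):\n",
    "            merged.setdefault(node_id, {}).update(metadata)\n",
    "    return merged\n",
    "\n",
    "\n",
    "node_ids = [\"node-0\", \"node-1\"]\n",
    "summaries = [\n",
    "    {\"section_summary\": \"summary of node-0\"},\n",
    "    {\"section_summary\": \"summary of node-1\"},\n",
    "]\n",
    "questions = [\n",
    "    {\"questions_this_excerpt_can_answer\": \"questions for node-0\"},\n",
    "    {\"questions_this_excerpt_can_answer\": \"questions for node-1\"},\n",
    "]\n",
    "merged = merge_metadata(node_ids, [summaries, questions])\n",
    "print(merged[\"node-1\"])\n",
    "# {'section_summary': 'summary of node-1', 'questions_this_excerpt_can_answer': 'questions for node-1'}"
   ]
  },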
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2873327d-420a-4778-a83b-6fdf7aa21bcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cache metadata dicts\n",
    "def save_metadata_dicts(path, data):\n",
    "    with open(path, \"w\") as fp:\n",
    "        json.dump(data, fp)\n",
    "\n",
    "\n",
    "def load_metadata_dicts(path):\n",
    "    with open(path, \"r\") as fp:\n",
    "        data = json.load(fp)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e318efb2-9afa-4414-b37f-71738d73d01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_metadata_dicts(\"data/llama2_metadata_dicts.json\", node_to_metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4edce99f-8a96-4539-95e7-62aeeabb2ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the cached metadata; lets you skip re-running the extractors\n",
    "node_to_metadata = load_metadata_dicts(\"data/llama2_metadata_dicts.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18d2109-5fcb-4fd5-b147-23897fed8787",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_nodes = the source nodes, plus one IndexNode per metadata value\n",
    "# (summary or questions) that refers back to its source node\n",
    "import copy\n",
    "\n",
    "all_nodes = copy.deepcopy(base_nodes)\n",
    "for node_id, metadata in node_to_metadata.items():\n",
    "    for val in metadata.values():\n",
    "        all_nodes.append(IndexNode(text=val, index_id=node_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f90ada6-0969-40cc-a4ec-3579b4900cdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_nodes_dict = {n.node_id: n for n in all_nodes}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22abc768-83d5-41d0-84f0-533899c76894",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a vector index over the source nodes and their metadata references,\n",
    "# using the same local embedding model as the other retrievers\n",
    "vector_index_metadata = VectorStoreIndex(all_nodes, embed_model=embed_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d53938a-1322-41b1-ad11-169b13b9805a",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_retriever_metadata = vector_index_metadata.as_retriever(\n",
    "    similarity_top_k=2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ae791f-c183-4ad4-9a3a-253288ded5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever_metadata = RecursiveRetriever(\n",
    "    \"vector\",\n",
    "    retriever_dict={\"vector\": vector_retriever_metadata},\n",
    "    node_dict=all_nodes_dict,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd85685-19eb-44cc-ad27-1d163eaddad6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-26<br>**Similarity:** 0.8727061238826861<br>**Text:** AsLLMsareintegratedanddeployed,welookforwardto\n",
       "continuing research that will amplify their potential for positive impact on these important social issues.\n",
       "4.2 Safety Fine-Tuning\n",
       "In this section, we describe our approach to safety fine-tuning, including safety categories, annotation\n",
       "guidelines,andthetechniquesweusetomitigatesafetyrisks. Weemployaprocesssimilartothegeneral\n",
       "fine-tuning methods as described in Section 3, with some notable differences related to safety concerns.\n",
       "Specifically, we use the following techniques in safety fine-tuning:\n",
       "1.Supervised Safety Fine-Tuning : We initialize by gathering adversarial prompts and safe demonstra-\n",
       "tions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches\n",
       "themodeltoalignwithoursafetyguidelinesevenbeforeRLHF,andthuslaysthefoundationfor\n",
       "high-quality human preference data annotation.\n",
       "2.Safety RLHF : Subsequently, we integrate safety in the general RLHF pipeline described in Sec-\n",
       "tion 3.2.2. This includes training a safety-specific reward model and gathering more challenging\n",
       "adversarial prompts for rejection sampling style fine-tuning and PPO optimization.\n",
       "3.SafetyContextDistillation : Finally,werefineourRLHFpipelinewithcontextdistillation(Askell\n",
       "etal.,2021b). Thisinvolvesgeneratingsafermodelresponsesbyprefixingapromptwithasafety\n",
       "preprompt, e.g., “You are a safe and responsible assistant,” and then fine-tuning the model on the safer\n",
       "responses without the preprompt, which essentially distillsthe safety preprompt (context) into the\n",
       "model. Weuseatargetedapproachthatallowsoursafetyrewardmodeltochoosewhethertouse\n",
       "context distillation for each sample.\n",
       "4.2.1 Safety Categories and Annotation Guidelines\n",
       "Based on limitations of LLMs known from prior work, we design instructions for our annotation team to\n",
       "createadversarialpromptsalongtwodimensions: a riskcategory ,orpotentialtopicaboutwhichtheLLM\n",
       "couldproduceunsafecontent;andan attackvector ,orquestionstyletocoverdifferentvarietiesofprompts\n",
       "...<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** node-26<br>**Similarity:** 0.8586079224453517<br>**Text:** AsLLMsareintegratedanddeployed,welookforwardto\n",
       "continuing research that will amplify their potential for positive impact on these important social issues.\n",
       "4.2 Safety Fine-Tuning\n",
       "In this section, we describe our approach to safety fine-tuning, including safety categories, annotation\n",
       "guidelines,andthetechniquesweusetomitigatesafetyrisks. Weemployaprocesssimilartothegeneral\n",
       "fine-tuning methods as described in Section 3, with some notable differences related to safety concerns.\n",
       "Specifically, we use the following techniques in safety fine-tuning:\n",
       "1.Supervised Safety Fine-Tuning : We initialize by gathering adversarial prompts and safe demonstra-\n",
       "tions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches\n",
       "themodeltoalignwithoursafetyguidelinesevenbeforeRLHF,andthuslaysthefoundationfor\n",
       "high-quality human preference data annotation.\n",
       "2.Safety RLHF : Subsequently, we integrate safety in the general RLHF pipeline described in Sec-\n",
       "tion 3.2.2. This includes training a safety-specific reward model and gathering more challenging\n",
       "adversarial prompts for rejection sampling style fine-tuning and PPO optimization.\n",
       "3.SafetyContextDistillation : Finally,werefineourRLHFpipelinewithcontextdistillation(Askell\n",
       "etal.,2021b). Thisinvolvesgeneratingsafermodelresponsesbyprefixingapromptwithasafety\n",
       "preprompt, e.g., “You are a safe and responsible assistant,” and then fine-tuning the model on the safer\n",
       "responses without the preprompt, which essentially distillsthe safety preprompt (context) into the\n",
       "model. Weuseatargetedapproachthatallowsoursafetyrewardmodeltochoosewhethertouse\n",
       "context distillation for each sample.\n",
       "4.2.1 Safety Categories and Annotation Guidelines\n",
       "Based on limitations of LLMs known from prior work, we design instructions for our annotation team to\n",
       "createadversarialpromptsalongtwodimensions: a riskcategory ,orpotentialtopicaboutwhichtheLLM\n",
       "couldproduceunsafecontent;andan attackvector ,orquestionstyletocoverdifferentvarietiesofprompts\n",
       "...<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "nodes = retriever_metadata.retrieve(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")\n",
    "for node in nodes:\n",
    "    display_source_node(node, source_length=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5285854a-69a6-4bc4-a2a5-1004cc790a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine_metadata = RetrieverQueryEngine.from_args(\n",
    "    retriever_metadata, llm=llm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e0ada5c-9a83-4517-bbb7-899d4415d68a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The key concepts for safety fine-tuning include supervised safety fine-tuning, safety RLHF (Reinforcement Learning from Human Feedback), and safety context distillation. Supervised safety fine-tuning involves gathering adversarial prompts and safe demonstrations to train the model to align with safety guidelines. Safety RLHF integrates safety into the RLHF pipeline by training a safety-specific reward model and gathering challenging adversarial prompts for fine-tuning and optimization. Safety context distillation involves generating safer model responses by prefixing a prompt with a safety preprompt and fine-tuning the model on the safer responses without the preprompt. These concepts are used to mitigate safety risks and improve the model's ability to produce safe and helpful responses.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine_metadata.query(\n",
    "    \"Can you tell me about the key concepts for safety finetuning\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9973bdca-d179-47d6-bd96-2631b36e1d94",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "We evaluate how well recursive retrieval with node references works, covering both chunk references and metadata references. In both cases, we use embedding similarity lookup to retrieve the reference nodes, which are then resolved to the bigger chunks they point to.\n",
    "\n",
    "We compare both methods against a baseline retriever that fetches the raw base chunks directly.\n",
    "\n",
    "We report two metrics: hit rate and mean reciprocal rank (MRR)."
   ]
  },
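  {
   "cell_type": "markdown",
   "id": "sketch-hit-rate-mrr-md",
   "metadata": {},
   "source": [
    "As a rough illustration of the two metrics (plain Python, not the LlamaIndex implementation): for a single query, **hit rate** is 1 if any expected node ID appears among the top-k retrieved IDs and 0 otherwise, and the **reciprocal rank** is `1 / rank` of the first relevant result (0 if nothing relevant is retrieved). MRR is the mean of the reciprocal ranks over all queries. The helper names `hit_rate` and `reciprocal_rank` and all node IDs below are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-hit-rate-mrr-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence, Set\n",
    "\n",
    "\n",
    "# Hypothetical helper: 1.0 if any expected ID appears in the retrieved list, else 0.0.\n",
    "def hit_rate(retrieved_ids: Sequence[str], expected_ids: Set[str]) -> float:\n",
    "    return 1.0 if any(rid in expected_ids for rid in retrieved_ids) else 0.0\n",
    "\n",
    "\n",
    "# Hypothetical helper: 1 / (1-based rank) of the first relevant result, 0.0 if none.\n",
    "def reciprocal_rank(retrieved_ids: Sequence[str], expected_ids: Set[str]) -> float:\n",
    "    for rank, rid in enumerate(retrieved_ids, start=1):\n",
    "        if rid in expected_ids:\n",
    "            return 1.0 / rank\n",
    "    return 0.0\n",
    "\n",
    "\n",
    "# Toy runs: (ranked retrieved IDs, expected IDs) for three made-up queries.\n",
    "runs = [\n",
    "    ([\"node-3\", \"node-7\", \"node-1\"], {\"node-7\"}),  # relevant at rank 2\n",
    "    ([\"node-2\", \"node-5\", \"node-9\"], {\"node-2\"}),  # relevant at rank 1\n",
    "    ([\"node-4\", \"node-6\", \"node-8\"], {\"node-0\"}),  # relevant not retrieved\n",
    "]\n",
    "\n",
    "# Average the per-query scores over all queries.\n",
    "mean_hit_rate = sum(hit_rate(r, e) for r, e in runs) / len(runs)\n",
    "mrr = sum(reciprocal_rank(r, e) for r, e in runs) / len(runs)\n",
    "print(f\"hit_rate={mean_hit_rate:.3f}, mrr={mrr:.3f}\")"
   ]
  },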
  {
   "cell_type": "markdown",
   "id": "1b3a30b7-2eb2-4eae-b0b9-1d4ec26ac915",
   "metadata": {},
   "source": [
    "### Dataset Generation\n",
    "\n",
    "We first generate an evaluation dataset of (question, context) pairs from the base text chunks (`base_nodes`), using `gpt-3.5-turbo` to write the questions."
   ]
  },
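  {
   "cell_type": "markdown",
   "id": "sketch-eval-dataset-md",
   "metadata": {},
   "source": [
    "For orientation, the generated `EmbeddingQAFinetuneDataset` stores three dictionaries: `queries` (query ID to question text), `corpus` (node ID to chunk text), and `relevant_docs` (query ID to the IDs of the chunks the question was generated from). A retriever scores a hit when it returns one of those chunk IDs. The plain-dict sketch below only mimics that layout with made-up IDs and texts; `eval_pairs` is a hypothetical helper, and the code does not call LlamaIndex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-eval-dataset-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Tuple\n",
    "\n",
    "# Made-up data laid out like the dataset's three dictionaries.\n",
    "queries: Dict[str, str] = {\n",
    "    \"q1\": \"Which techniques are used in safety fine-tuning?\",\n",
    "    \"q2\": \"What is an attack vector in the annotation guidelines?\",\n",
    "}\n",
    "corpus: Dict[str, str] = {\n",
    "    \"node-a\": \"(text of the chunk about safety fine-tuning)\",\n",
    "    \"node-b\": \"(text of the chunk about annotation guidelines)\",\n",
    "}\n",
    "relevant_docs: Dict[str, List[str]] = {\"q1\": [\"node-a\"], \"q2\": [\"node-b\"]}\n",
    "\n",
    "\n",
    "def eval_pairs(\n",
    "    queries: Dict[str, str], relevant_docs: Dict[str, List[str]]\n",
    ") -> List[Tuple[str, List[str]]]:\n",
    "    # Pair each question with the node IDs a retriever should return for it.\n",
    "    return [(queries[qid], relevant_docs[qid]) for qid in queries]\n",
    "\n",
    "\n",
    "for question, expected_ids in eval_pairs(queries, relevant_docs):\n",
    "    # Every expected ID should refer to a chunk that exists in the corpus.\n",
    "    assert all(node_id in corpus for node_id in expected_ids)\n",
    "    print(question, \"->\", expected_ids)"
   ]
  },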
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe8ae8a-a2b2-4515-bcff-1145e14ede3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.evaluation import (\n",
    "    generate_question_context_pairs,\n",
    "    EmbeddingQAFinetuneDataset,\n",
    ")\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "import nest_asyncio\n",
    "\n",
    "# let code that runs its own asyncio event loop work inside Jupyter's running loop\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eef1b43d-996b-4b0a-becb-1cec08d9f8c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 93/93 [02:08<00:00,  1.38s/it]\n"
     ]
    }
   ],
   "source": [
    "eval_dataset = generate_question_context_pairs(\n",
    "    base_nodes, OpenAI(model=\"gpt-3.5-turbo\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3e2507-9157-48a5-909b-18eeb9ec01d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset.save_json(\"data/llama2_eval_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611f07af-2006-4158-8dc6-59d11a269c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: reload the saved dataset instead of regenerating it\n",
    "eval_dataset = EmbeddingQAFinetuneDataset.from_json(\n",
    "    \"data/llama2_eval_dataset.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4782a6-f3da-453f-93be-7683ed15b508",
   "metadata": {},
   "source": [
    "### Compare Results\n",
    "\n",
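    "We evaluate each retriever on the same dataset and measure hit rate and MRR.\n",
    "\n",
    "In this run, both retrievers with node references (chunk or metadata) outperform retrieving the raw chunks directly: hit rate rises from about 0.78 for the base retriever to about 0.90 (chunk references) and 0.89 (metadata references), and MRR rises from about 0.56 to 0.69 and 0.72 respectively."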
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87798866-11bc-4f7f-b8aa-0a023309492f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from llama_index.core.evaluation import (\n",
    "    RetrieverEvaluator,\n",
    "    get_retrieval_results_df,\n",
    ")\n",
    "\n",
    "# use a higher similarity_top_k for the vector retrievers during evaluation\n",
    "top_k = 10\n",
    "\n",
    "\n",
    "def display_results(names, results_arr):\n",
    "    \"\"\"Display the mean hit rate and MRR of each retriever as a table.\"\"\"\n",
    "\n",
    "    hit_rates = []\n",
    "    mrrs = []\n",
    "    for name, eval_results in zip(names, results_arr):\n",
    "        metric_dicts = []\n",
    "        for eval_result in eval_results:\n",
    "            metric_dict = eval_result.metric_vals_dict\n",
    "            metric_dicts.append(metric_dict)\n",
    "        results_df = pd.DataFrame(metric_dicts)\n",
    "\n",
    "        hit_rate = results_df[\"hit_rate\"].mean()\n",
    "        mrr = results_df[\"mrr\"].mean()\n",
    "        hit_rates.append(hit_rate)\n",
    "        mrrs.append(mrr)\n",
    "\n",
    "    final_df = pd.DataFrame(\n",
    "        {\"retrievers\": names, \"hit_rate\": hit_rates, \"mrr\": mrrs}\n",
    "    )\n",
    "    display(final_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6d142c6-0374-43ec-af31-e02d246bd815",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_retriever_chunk = vector_index_chunk.as_retriever(\n",
    "    similarity_top_k=top_k\n",
    ")\n",
    "retriever_chunk = RecursiveRetriever(\n",
    "    \"vector\",\n",
    "    retriever_dict={\"vector\": vector_retriever_chunk},\n",
    "    node_dict=all_nodes_dict,\n",
    "    verbose=True,\n",
    ")\n",
    "retriever_evaluator = RetrieverEvaluator.from_metric_names(\n",
    "    [\"mrr\", \"hit_rate\"], retriever=retriever_chunk\n",
    ")\n",
    "# evaluate on the entire dataset\n",
    "results_chunk = await retriever_evaluator.aevaluate_dataset(\n",
    "    eval_dataset, show_progress=True\n",
    ")"
   ]
  },
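  {
   "cell_type": "markdown",
   "id": "sketch-reference-resolution-md",
   "metadata": {},
   "source": [
    "Conceptually, the recursive retriever first runs the similarity search over the reference nodes (smaller chunks, or summaries and generated questions for the metadata variant) and then swaps each hit for the bigger chunk it points to in `node_dict`, so its results can be scored against the IDs of the base chunks the questions were generated from. The toy sketch below models only that lookup step with plain dictionaries; it is not the `RecursiveRetriever` implementation, and `reference_to_parent`, `parent_nodes`, `resolve_references`, and all IDs and scores are hypothetical. The sketch does not deduplicate, so a parent chunk can appear more than once when several of its references match."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-reference-resolution-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Tuple\n",
    "\n",
    "# Hypothetical reference IDs (smaller chunks or summaries) -> parent chunk ID.\n",
    "reference_to_parent: Dict[str, str] = {\n",
    "    \"node-a-sub-0\": \"node-a\",\n",
    "    \"node-a-summary\": \"node-a\",\n",
    "    \"node-b-sub-1\": \"node-b\",\n",
    "}\n",
    "# Hypothetical parent chunks keyed by ID (the role all_nodes_dict plays above).\n",
    "parent_nodes: Dict[str, str] = {\n",
    "    \"node-a\": \"(text of parent chunk A)\",\n",
    "    \"node-b\": \"(text of parent chunk B)\",\n",
    "}\n",
    "\n",
    "\n",
    "def resolve_references(\n",
    "    hits: List[Tuple[str, float]]\n",
    ") -> List[Tuple[str, str, float]]:\n",
    "    # Swap each retrieved reference for the chunk it points to, keeping its score.\n",
    "    # IDs that are not references are treated as pointing to themselves.\n",
    "    resolved = []\n",
    "    for ref_id, score in hits:\n",
    "        parent_id = reference_to_parent.get(ref_id, ref_id)\n",
    "        resolved.append((parent_id, parent_nodes[parent_id], score))\n",
    "    return resolved\n",
    "\n",
    "\n",
    "# Ranked similarity hits over the reference nodes (made-up scores).\n",
    "hits = [\n",
    "    (\"node-a-summary\", 0.86),\n",
    "    (\"node-a-sub-0\", 0.85),\n",
    "    (\"node-b-sub-1\", 0.80),\n",
    "]\n",
    "for parent_id, text, score in resolve_references(hits):\n",
    "    print(parent_id, score, text)"
   ]
  },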
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae448fe7-3a66-45a6-8e8e-6ed3950e61b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_retriever_metadata = vector_index_metadata.as_retriever(\n",
    "    similarity_top_k=top_k\n",
    ")\n",
    "retriever_metadata = RecursiveRetriever(\n",
    "    \"vector\",\n",
    "    retriever_dict={\"vector\": vector_retriever_metadata},\n",
    "    node_dict=all_nodes_dict,\n",
    "    verbose=True,\n",
    ")\n",
    "retriever_evaluator = RetrieverEvaluator.from_metric_names(\n",
    "    [\"mrr\", \"hit_rate\"], retriever=retriever_metadata\n",
    ")\n",
    "# evaluate on the entire dataset\n",
    "results_metadata = await retriever_evaluator.aevaluate_dataset(\n",
    "    eval_dataset, show_progress=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d3fc029-7ccc-4ec4-b391-b7b86744b5d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 194/194 [00:09<00:00, 19.86it/s]\n"
     ]
    }
   ],
   "source": [
    "base_retriever = base_index.as_retriever(similarity_top_k=top_k)\n",
    "retriever_evaluator = RetrieverEvaluator.from_metric_names(\n",
    "    [\"mrr\", \"hit_rate\"], retriever=base_retriever\n",
    ")\n",
    "# evaluate on the entire dataset\n",
    "results_base = await retriever_evaluator.aevaluate_dataset(\n",
    "    eval_dataset, show_progress=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ef0cd73-b1ad-4ec6-931f-357d2ceebd65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>retrievers</th>\n",
       "      <th>hit_rate</th>\n",
       "      <th>mrr</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Base Retriever</td>\n",
       "      <td>0.778351</td>\n",
       "      <td>0.563103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Retriever (Chunk References)</td>\n",
       "      <td>0.896907</td>\n",
       "      <td>0.691114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Retriever (Metadata References)</td>\n",
       "      <td>0.891753</td>\n",
       "      <td>0.718440</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        retrievers  hit_rate       mrr\n",
       "0                   Base Retriever  0.778351  0.563103\n",
       "1     Retriever (Chunk References)  0.896907  0.691114\n",
       "2  Retriever (Metadata References)  0.891753  0.718440"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "full_results_df = get_retrieval_results_df(\n",
    "    [\n",
    "        \"Base Retriever\",\n",
    "        \"Retriever (Chunk References)\",\n",
    "        \"Retriever (Metadata References)\",\n",
    "    ],\n",
    "    [results_base, results_chunk, results_metadata],\n",
    ")\n",
    "display(full_results_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
